import os
import requests
import types
import json
import csv
import pickle

import numpy as np
from sklearn.preprocessing import label_binarize
import scipy.io

import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.datasets import StochasticBlockModelDataset
from torch_geometric.utils import to_undirected, add_remaining_self_loops, convert, subgraph
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

from data_utils import get_mask, set_train_val_test_split

from nifty.utils import load_credit, load_german, load_bail

# Root directory (relative path) under which the loaders in this module store and
# look up dataset files, e.g. os.path.join(DATA_PATH, name) or f"{DATA_PATH}/FB100/".
DATA_PATH = 'data'

# +
def keep_only_largest_connected_component(dataset):
    """Shrink ``dataset.data`` to the nodes of its largest connected component.

    The component is obtained from ``get_largest_connected_component``; edges
    with both endpoints inside it are kept and relabelled via ``subgraph``.

    Args:
        dataset: object exposing ``dataset.data`` with ``x``, ``y``,
            ``edge_index`` and (optionally) ``edge_attr``.

    Returns:
        The same ``dataset`` object, with ``dataset.data`` replaced by a new
        ``Data`` holding only ``x``, ``edge_index``, ``y`` and ``edge_attr``.
        Any other attribute of the previous data object is not copied over.
    """
    lcc = get_largest_connected_component(dataset)
    old_data = dataset.data

    # `subgraph` is documented for tensors / lists of indices, so convert the
    # NumPy index array once and use the same tensor for every lookup.
    subset = torch.as_tensor(lcc, dtype=torch.long)

    # Not every dataset carries edge attributes; only reshape when present.
    edge_attr = old_data.edge_attr
    if edge_attr is not None:
        edge_attr = edge_attr.reshape(-1, 1)

    # The SimpleNamespace datasets built in this module have no `num_nodes`;
    # fall back to the feature-matrix row count in that case.
    num_nodes = getattr(dataset, 'num_nodes', None)
    if num_nodes is None:
        num_nodes = old_data.x.shape[0]

    lcc_edge_index, lcc_edge_attr = subgraph(
        subset,
        old_data.edge_index,
        edge_attr,
        relabel_nodes=True,
        num_nodes=num_nodes,
    )

    dataset.data = Data(
        x=old_data.x[subset],
        edge_index=lcc_edge_index,
        y=old_data.y[subset],
        edge_attr=lcc_edge_attr,
    )

    return dataset

def get_component(dataset: InMemoryDataset, start: int = 0) -> set:
    """Collect every node reachable from ``start``.

    Edges are followed in their stored direction only: from the first row of
    ``dataset.data.edge_index`` to the second.

    Args:
        dataset: object whose ``data.edge_index`` can be converted with ``.numpy()``.
        start: node id the traversal begins at.

    Returns:
        Set containing ``start`` and all nodes reached from it.
    """
    sources, targets = dataset.data.edge_index.numpy()
    seen = set()
    frontier = {start}
    while frontier:
        node = frontier.pop()
        seen.add(node)
        # Out-neighbours of `node`, in edge-list order.
        for nbr in targets[sources == node]:
            if nbr in seen or nbr in frontier:
                continue
            frontier.add(nbr)
    return seen


def get_largest_connected_component(dataset: InMemoryDataset) -> np.ndarray:
    """Return the node ids of the biggest component found by ``get_component``.

    Components are grown one at a time, always starting from the smallest node
    id not yet assigned. On ties the first component found wins.

    Args:
        dataset: object with ``data.x`` (one row per node) and ``data.edge_index``.

    Returns:
        1-D array with the node ids of the largest component.
    """
    unassigned = set(range(dataset.data.x.shape[0]))
    components = []
    while unassigned:
        seed_node = min(unassigned)
        component = get_component(dataset, seed_node)
        components.append(component)
        unassigned -= component
    sizes = [len(component) for component in components]
    return np.array(list(components[np.argmax(sizes)]))


# -

def get_dataset(name: str, seed: int, use_lcc: bool = False, homophily=None):
    """Load the dataset called ``name`` and post-process its edge list.

    After loading, the edge index is made undirected and self loops are added
    for every node that does not already have one.

    Args:
        name: dataset identifier. Supported: 'credit', 'german', 'bail', any
            name starting with 'sbm', 'Cora', 'CiteSeer', 'PubMed',
            'Computers', 'Photo', 'CoauthorCS', 'CoauthorPhysics',
            'OGBN-Arxiv', 'OGBN-Products', 'Twitch', 'Deezer-Europe', 'FB100',
            'Actor', 'Syn-Cora', 'MixHopSynthetic'.
        seed: forwarded to ``load_sbm``; unused by the other branches.
        use_lcc: keep only the largest connected component. Forced to False
            for the OGBN, Twitch, Deezer-Europe, FB100 and Actor datasets.
        homophily: forwarded to ``load_syn_cora`` / ``MixHopSyntheticDataset``.

    Returns:
        Tuple ``(dataset, sens, evaluator)``. ``sens`` is None except for the
        fairness and SBM datasets; ``evaluator`` is None except for OGBN.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    # The PyG dataset classes used below are not imported at module level, so
    # resolve them lazily here. Attribute access keeps each lookup confined to
    # the branch that needs it.
    from torch_geometric import datasets as pyg_datasets

    path = os.path.join(DATA_PATH, name)
    evaluator = None
    sens = None

    if name in ('credit', 'german', 'bail'):
        dataset, sens = load_fairness_datasets(name)
    elif name.startswith('sbm'):
        dataset, sens = load_sbm(name, seed)
    elif name in ('Cora', 'CiteSeer', 'PubMed'):
        dataset = pyg_datasets.Planetoid(path, name)
    elif name in ('Computers', 'Photo'):
        dataset = pyg_datasets.Amazon(path, name)
    elif name == 'CoauthorCS':
        dataset = pyg_datasets.Coauthor(path, 'CS')
    elif name == 'CoauthorPhysics':
        dataset = pyg_datasets.Coauthor(path, 'Physics')
    elif name in ('OGBN-Arxiv', 'OGBN-Products'):
        import torch_geometric.transforms as transforms

        dataset = PygNodePropPredDataset(name=name.lower(), transform=transforms.ToSparseTensor(), root=path)
        evaluator = Evaluator(name=name.lower())
        use_lcc = False
    elif name == "Twitch":
        dataset = load_twitch_dataset("DE")
        use_lcc = False
    elif name == "Deezer-Europe":
        dataset = pyg_datasets.DeezerEurope(path)
        use_lcc = False
    elif name == "FB100":
        sub_dataname = 'Penn94'
        dataset = load_fb100_dataset(sub_dataname)
        use_lcc = False
    elif name == "Actor":
        dataset = pyg_datasets.Actor(path)
        use_lcc = False
    elif name == 'Syn-Cora':
        dataset = load_syn_cora(homophily)
    elif name == 'MixHopSynthetic':
        dataset = pyg_datasets.MixHopSyntheticDataset(path, homophily=homophily)
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError('Unknown dataset.')

    if use_lcc:
        dataset = keep_only_largest_connected_component(dataset)

    # Make graph undirected so that we have edges for both directions and add self loops
    dataset.data.edge_index = to_undirected(dataset.data.edge_index)
    dataset.data.edge_index, _ = add_remaining_self_loops(dataset.data.edge_index, num_nodes=dataset.data.x.shape[0])
    print("Data: ", dataset.data)

    return dataset, sens, evaluator

def load_sbm(name, seed, num_nodes=1000, num_features=10):
    """Build a two-block stochastic block model dataset described by ``name``.

    ``name`` has the form ``sbm-<Q block rate>-<intra-link rate>-<inter-link rate>``;
    all three numbers are divided by 100 before use.

    Args:
        name: dataset name encoding the three rates; also passed as the first
            argument of ``StochasticBlockModelDataset`` and to
            ``set_train_val_test_split``.
        seed: forwarded to ``set_train_val_test_split``.
        num_nodes: total number of nodes across both blocks.
        num_features: passed as ``num_channels`` to the generator.

    Returns:
        Tuple ``(dataset, sens)`` where ``dataset`` is a namespace with ``data``
        and ``num_classes``, and ``sens`` is the generated label tensor (the
        labels double as the sensitive attribute).
    """
    Q_rate, intra_rate, inter_rate = list(map(float, name.split('-')[1:]))
    Q_size = int(Q_rate / 100 * num_nodes)

    p_intra = intra_rate / 100
    p_inter = inter_rate / 100

    # Undirected graph; the block membership is the label.
    sbm = StochasticBlockModelDataset(
        name,
        block_sizes=[Q_size, num_nodes - Q_size],
        edge_probs=[[p_intra, p_inter], [p_inter, p_intra]],
        num_channels=num_features,
        is_undirected=True,
    )
    generated = sbm.data

    # labels are sensitive attributes
    sens = generated.y
    edge_index, _ = add_remaining_self_loops(generated.edge_index, num_nodes=generated.x.shape[0])

    graph = Data(
        x=generated.x,
        edge_index=edge_index,
        edge_attr=torch.ones(edge_index.size(1)),
        y=generated.y,
    )
    graph = set_train_val_test_split(seed, graph, name)

    wrapped = types.SimpleNamespace(data=graph, num_classes=graph.y.max().item() + 1)
    return wrapped, sens


# +
def load_fairness_datasets(name):
    """Load the 'credit', 'german' or 'bail' fairness graph via the NIFTY loaders.

    The sensitive-attribute column is removed from the feature matrix so the
    model cannot read it directly; it is returned separately instead.

    Args:
        name: one of 'credit', 'german', 'bail'.

    Returns:
        Tuple ``(dataset, sens)``. ``dataset`` is a namespace with ``data``
        (features, edge index, labels, all-ones ``edge_attr`` and
        train/val/test masks) and ``num_classes``; ``sens`` is a LongTensor
        built from the loader's sensitive-attribute output.

    Raises:
        ValueError: if ``name`` is not one of the three supported datasets.
    """
    # Per dataset: loader, sensitive attribute, its column index after feature
    # processing, prediction target, and the label_number handed to the loader.
    if name == 'credit':
        loader, sens_attr, sens_idx, predict_attr, label_number = (
            load_credit, "Age", 1, 'NoDefaultNextMonth', 6000)
    elif name == 'german':
        loader, sens_attr, sens_idx, predict_attr, label_number = (
            load_german, "Gender", 0, "GoodCustomer", 100)
    elif name == 'bail':
        loader, sens_attr, sens_idx, predict_attr, label_number = (
            load_bail, "WHITE", 0, "RECID", 100)
    else:
        # Previously an unknown name fell through and failed later on an
        # unbound local; fail early with a clear message instead.
        raise ValueError(f"Unknown fairness dataset: {name!r}")

    adj, features, labels, idx_train, idx_val, idx_test, sens = loader(
        name, sens_attr, predict_attr,
        path=f"nifty/dataset/{name}",
        label_number=label_number,
    )
    # NOTE: the earlier version carried a commented-out feature_norm step for
    # 'credit' and 'bail'; features are still used as returned by the loader.

    edge_index = convert.from_scipy_sparse_matrix(adj)[0]
    # don't provide model with access to sensitive attributes
    features = torch.cat((features[:, :sens_idx], features[:, sens_idx + 1:]), dim=1)
    sens = torch.LongTensor(sens)

    data = Data(
        x=features,
        edge_index=edge_index,
        y=labels.long(),
        edge_attr=torch.ones(edge_index.size(1)),
    )
    num_nodes = features.size(0)
    data.train_mask = get_mask(idx_train, num_nodes)
    data.val_mask = get_mask(idx_val, num_nodes)
    data.test_mask = get_mask(idx_test, num_nodes)

    dataset = types.SimpleNamespace()
    dataset.data = data
    dataset.num_classes = data.y.max().item() + 1

    return dataset, sens


# -

def load_fb100(filename):
    """Return the adjacency matrix and node metadata of a Facebook100 graph.

    The ``.mat`` file is looked up under ``DATA_PATH/FB100`` and downloaded
    from the Non-Homophily-Benchmarks repository when it is not there yet.

    Args:
        filename: graph name without the ``.mat`` extension (e.g. 'Penn94').

    Returns:
        Tuple ``(A, metadata)`` with the ``'A'`` and ``'local_info'`` entries
        of the loaded ``.mat`` file.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    fb100_dir = os.path.join(DATA_PATH, "FB100")
    # makedirs also creates DATA_PATH itself when it does not exist yet.
    os.makedirs(fb100_dir, exist_ok=True)
    mat_path = os.path.join(fb100_dir, f"{filename}.mat")

    # Check the path that is actually written (with the .mat suffix); the old
    # check omitted the suffix and therefore re-downloaded on every call.
    if not os.path.isfile(mat_path):
        url = f"https://github.com/CUAI/Non-Homophily-Benchmarks/raw/5b2ffa908274f9929b95402b71c9b645928f292c/data/facebook100/{filename}.mat"
        r = requests.get(url, allow_redirects=True)
        # Do not cache an HTTP error page as if it were the dataset.
        r.raise_for_status()
        with open(mat_path, "wb") as f:
            f.write(r.content)

    mat = scipy.io.loadmat(mat_path)
    A = mat['A']
    metadata = mat['local_info']
    return A, metadata

def load_fb100_dataset(filename):
    """Build a graph dataset from a Facebook100 ``.mat`` file.

    Metadata column 1 (minus one) is used as the label, so -1 marks unlabeled
    nodes. Column 0 and columns 2 onwards are encoded column by column with
    ``label_binarize`` and concatenated into the feature matrix.

    Args:
        filename: graph name passed on to ``load_fb100``.

    Returns:
        Namespace with ``data`` (features, edge index, labels) and
        ``num_classes`` (largest label + 1).
    """
    A, metadata = load_fb100(filename)
    # Stack the (row, col) index arrays into one ndarray before handing them to torch.
    edge_index = torch.tensor(np.vstack(A.nonzero()), dtype=torch.long)
    # `np.int` was only an alias of the builtin and no longer exists in NumPy >= 1.24.
    metadata = metadata.astype(int)
    label = metadata[:, 1] - 1  # gender label, -1 means unlabeled

    # make features into one-hot encodings
    feature_vals = np.hstack(
        (np.expand_dims(metadata[:, 0], 1), metadata[:, 2:]))
    # Start from an empty float block so the result keeps its float dtype and
    # the stack is valid even with zero feature columns; stack once at the end.
    encoded = [np.empty((A.shape[0], 0))]
    for col in range(feature_vals.shape[1]):
        feat_col = feature_vals[:, col]
        encoded.append(label_binarize(feat_col, classes=np.unique(feat_col)))
    features = np.hstack(encoded)

    node_feat = torch.tensor(features, dtype=torch.float)

    data = Data(
        x=node_feat,
        edge_index=edge_index,
        y=torch.tensor(label),
    )

    # This allows to just have a general object to which we can assign fields
    dataset = types.SimpleNamespace()
    dataset.data = data
    dataset.num_classes = data.y.max().item() + 1

    return dataset

def load_twitch_dataset(lang):
    """Wrap the arrays returned by ``load_twitch`` into a graph dataset.

    Args:
        lang: language code forwarded to ``load_twitch``.

    Returns:
        Namespace with ``data`` (float features, long edge index built from the
        non-zero entries of the adjacency matrix, labels) and ``num_classes``
        (largest label + 1).
    """
    adjacency, labels, feature_matrix = load_twitch(lang)

    graph = Data(
        x=torch.tensor(feature_matrix, dtype=torch.float),
        edge_index=torch.tensor(adjacency.nonzero(), dtype=torch.long),
        y=torch.tensor(labels),
    )

    # A plain namespace is enough: callers only read `.data` and `.num_classes`.
    return types.SimpleNamespace(data=graph, num_classes=graph.y.max().item() + 1)

def load_twitch(lang):
    """Load the Twitch graph for one language, downloading it if necessary.

    The three ``musae_<lang>_*`` files are cached under ``DATA_PATH/Twitch``.

    Args:
        lang: one of 'DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'.

    Returns:
        Tuple ``(A, label, features)``:
        ``A`` is an n x n CSR matrix with a 1 for every row of the edge file,
        ``label`` is an integer array where position i holds the label of node
        id i, and ``features`` is a dense 0/1 matrix with all-zero columns
        removed.

    Raises:
        AssertionError: if ``lang`` is not a supported language code.
        requests.HTTPError: if a download returns an error status.
    """
    # Explicit import: the module header only imports `scipy.io`.
    import scipy.sparse

    assert lang in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'), 'Invalid dataset'

    twitch_dir = os.path.join(DATA_PATH, "Twitch")
    # makedirs also creates DATA_PATH itself when it does not exist yet.
    os.makedirs(twitch_dir, exist_ok=True)

    target_path = os.path.join(twitch_dir, f"musae_{lang}_target.csv")
    edges_path = os.path.join(twitch_dir, f"musae_{lang}_edges.csv")
    features_path = os.path.join(twitch_dir, f"musae_{lang}_features.json")

    # Fetch the files of the requested language; these names and the URL used
    # to be hard-coded to DE, so every other language failed on open().
    for local_path in (target_path, edges_path, features_path):
        if not os.path.isfile(local_path):
            file = os.path.basename(local_path)
            url = f"https://github.com/CUAI/Non-Homophily-Benchmarks/raw/5b2ffa908274f9929b95402b71c9b645928f292c/data/twitch/{lang}/{file}"
            r = requests.get(url, allow_redirects=True)
            # Do not cache an HTTP error page as if it were the dataset.
            r.raise_for_status()
            with open(local_path, "wb") as f:
                f.write(r.content)

    label = []
    node_ids = []
    uniq_ids = set()
    with open(target_path, 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        for row in reader:
            node_id = int(row[5])
            # handle FR case of non-unique rows
            if node_id not in uniq_ids:
                uniq_ids.add(node_id)
                label.append(int(row[2] == "True"))
                node_ids.append(node_id)

    src = []
    targ = []
    with open(edges_path, 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        for row in reader:
            src.append(int(row[0]))
            targ.append(int(row[1]))

    with open(features_path, 'r') as f:
        j = json.load(f)

    # `np.int` was only an alias of the builtin and no longer exists in NumPy >= 1.24.
    node_ids = np.array(node_ids, dtype=int)
    src = np.array(src)
    targ = np.array(targ)
    label = np.array(label)
    n = label.shape[0]

    # reorder_node_ids[i] is the row of the target file that describes node id i
    # (a KeyError here means the ids are not exactly 0..n-1).
    inv_node_ids = {node_id: idx for (idx, node_id) in enumerate(node_ids)}
    reorder_node_ids = np.zeros_like(node_ids)
    for i in range(n):
        reorder_node_ids[i] = inv_node_ids[i]

    A = scipy.sparse.csr_matrix((np.ones(len(src)), (src, targ)), shape=(n, n))

    features = np.zeros((n, 3170))
    for node, feats in j.items():
        node_idx = int(node)
        if node_idx >= n:
            continue
        features[node_idx, np.array(feats, dtype=int)] = 1
    features = features[:, np.sum(features, axis=0) != 0]  # remove zero cols

    label = label[reorder_node_ids]

    return A, label, features

def load_syn_cora(homophily):
    """Load the pickled Syn-Cora graph for the given homophily level.

    Reads ``data/syn-cora/<homophily>-0.p``.

    Args:
        homophily: value interpolated into the file name; must not be None.

    Returns:
        Namespace with ``data`` (the unpickled object) and ``num_classes``
        (largest value in ``data.y`` + 1).

    Raises:
        ValueError: if ``homophily`` is None.
    """
    if homophily is None:
        raise ValueError('Specify a level of homophily.')

    # Use a context manager so the file handle is closed after loading.
    # NOTE(security): pickle executes arbitrary code on load; only use trusted files.
    with open(f"data/syn-cora/{homophily}-0.p", "rb") as f:
        data = pickle.load(f)

    dataset = types.SimpleNamespace()
    dataset.data = data
    dataset.num_classes = data.y.max().item() + 1
    return dataset
